# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates provider factories."""

load("@bazel_skylib//lib:structs.bzl", "structs")
load("@rules_testing//lib:truth.bzl", "subjects")

visibility("private")

def generate_factory(type, name, attrs):
    """Generates a factory for a custom struct.

    There are three reasons we need to do so:
    1. It's very difficult to read providers printed by these types.
        eg. If you have a 10 layer deep diamond dependency graph, and try to
        print the top value, the bottom value will be printed 2^10 times.
    2. Collections of subjects are not well supported by rules_testing
        eg. `FeatureInfo(flag_sets = [FlagSetInfo(...)])`
        (You can do it, but the inner values are just regular bazel structs and
        you can't do fluent assertions on them).
    3. Recursive types are not supported at all
        eg. `FeatureInfo(implies = depset([FeatureInfo(...)]))`

    To solve this, we create a factory that:
    * Validates that the types of the children are correct.
    * Inlines providers to their labels when unambiguous.

    For example, given:

    ```
    foo = FeatureInfo(name = "foo", label = Label("//:foo"))
    bar = FeatureInfo(..., implies = depset([foo]))
    ```

    It would convert itself into a subject for the following struct:
    `FeatureInfo(..., implies = depset([Label("//:foo")]))`

    Args:
        type: (type) The type to create a factory for (eg. FooInfo)
        name: (str) The name of the type (eg. "FooInfo")
        attrs: (dict[str, Factory]) The attributes associated with this type.
            A `label` attribute (using `subjects.label`) is always added; the
            caller's dict itself is not modified.

    Returns:
        A struct `FooFactory` with fields `type`, `name`, `factory` and
        `validate`, suitable for use with
        * `analysis_test(provider_subject_factories=[FooFactory])`
        * `generate_factory(..., attrs=dict(foo = FooFactory))`
        * `ProviderSequence(FooFactory)`
        * `ProviderDepset(FooFactory)`
    """

    # Copy instead of mutating the caller's dict: it may be shared with other
    # callers, or frozen (e.g. a dict constant loaded from another .bzl file).
    attrs = dict(attrs, label = subjects.label)

    want_keys = sorted(attrs.keys())

    def validate(*, value, meta):
        # Records a failure on `meta` unless `value` has exactly the fields in
        # `attrs` (including `label`).
        if value == None:
            meta.add_failure("Wanted a %s but got" % name, value)

            # There are no fields to compare; stop here rather than also
            # reporting a misleading "different set of fields" failure.
            return
        got_keys = sorted(structs.to_dict(value).keys())
        subjects.collection(got_keys, meta = meta.derive(details = [
            "Value %r was not a %s - it has a different set of fields" % (value, name),
        ])).contains_exactly(want_keys).in_order()

    def type_factory(value, *, meta):
        validate(value = value, meta = meta)

        transformed_value = {}
        transformed_factories = {}
        for field, factory in attrs.items():
            field_value = getattr(value, field)

            # Factories built by generate_factory are structs with a `factory`
            # field; plain subject factories are not. Nested generated types
            # are validated, then replaced by their label.
            if hasattr(factory, "factory"):
                factory.validate(value = field_value, meta = meta.derive(field))
                transformed_value[field] = field_value.label
                transformed_factories[field] = subjects.label
            else:
                transformed_value[field] = field_value
                transformed_factories[field] = factory

        return subjects.struct(
            struct(**transformed_value),
            meta = meta,
            attrs = transformed_factories,
        )

    return struct(
        type = type,
        name = name,
        factory = type_factory,
        validate = validate,
    )

def _provider_collection(element_factory, fn):
    """Builds a subject factory for a collection of generated providers.

    Args:
        element_factory: A factory struct returned by `generate_factory`; its
            `validate` function is applied to every element.
        fn: Converts the actual value into a list of providers.

    Returns:
        A subject factory `(value, *, meta)` that validates each element and
        returns a `subjects.collection` of the elements' labels.
    """

    def factory(value, *, meta):
        providers = fn(value)

        # Check each element really is the expected type, reporting problems
        # against its position in the collection.
        for idx, provider in enumerate(providers):
            element_factory.validate(
                value = provider,
                meta = meta.derive("offset({})".format(idx)),
            )

        # Inline each provider to just its label.
        labels = [provider.label for provider in providers]
        return subjects.collection(labels, meta = meta)

    return factory

# These act like classes, so they are named like classes.

# Subject factory for any iterable of providers created via `generate_factory`
# (the value is converted with `list`).
# buildifier: disable=name-conventions
ProviderSequence = lambda element_factory: _provider_collection(
    element_factory,
    fn = list,
)

# Subject factory for a depset of providers created via `generate_factory`.
# buildifier: disable=name-conventions
ProviderDepset = lambda element_factory: _provider_collection(
    element_factory,
    fn = lambda providers: providers.to_list(),
)
